用人工神经元构建线性分类器——感知机(Perceptron)

为了让人工神经元模型实现具体的功能,我们随后为它引入了一个激活函数(符号函数 sign),即输出 $\hat{y} = \operatorname{sign}(\mathbf{w}^{\top}\mathbf{x}) \in \{+1, -1\}$,从而使人工神经元具备二元线性分类的能力。这套新的模型被称为“感知机(perceptron)”。

example example

下面给出感知机的训练算法:

输入:数据集 $D = \{(\mathbf{x}_1, y_1), (\mathbf{x}_2, y_2), \dots, (\mathbf{x}_n, y_n)\}$,其中标签 $y_i \in \{+1, -1\}$;最大训练轮数 $E$

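  1. 初始化:$\mathbf{w}_0 \leftarrow \mathbf{0}$,$k \leftarrow 0$
  2. for $e = 1$ to $E$ do
  3.     感知机遍历一次数据集 $D$:
  4.     $i \leftarrow 1$
  5.     while $i \le n$ do
  6.         $\hat{y}_i \leftarrow \mathbf{w}_k^{\top}\mathbf{x}_i$
  7.         if $y_i\,\hat{y}_i \le 0$ then   // 样本被误分类(或恰好落在决策边界上)
  8.             $\mathbf{w}_{k+1} \leftarrow \mathbf{w}_k + y_i\,\mathbf{x}_i$;$k \leftarrow k + 1$
  9.         $i \leftarrow i + 1$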

输出:最终的权重向量 $\mathbf{w}_k$

下面给出一个案例,以及可供参考的、使用 Java 语言完成训练的代码:

example example
import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;

        /**
         * A single labeled training sample: a feature vector {@code x} and a class label {@code y}.
         *
         * <p>The example dataset uses 3-dimensional feature vectors; callers should make the
         * length match the dimension of the perceptron the sample is fed to.
         */
        class DataPoint {
            /** Feature vector (stored by reference, not copied). */
            public double[] x;
            /** Class label, always +1 or -1 (enforced by the constructor). */
            public int y;

            /**
             * Creates a labeled sample.
             *
             * @param x feature vector; must not be null
             * @param y class label; must be +1 or -1, because the perceptron update
             *     {@code w = w + y * x} and the misclassification test {@code y * (w^T x) <= 0}
             *     only make sense for a ±1 label
             * @throws NullPointerException if {@code x} is null
             * @throws IllegalArgumentException if {@code y} is neither +1 nor -1
             */
            public DataPoint(double[] x, int y) {
                Objects.requireNonNull(x, "x must not be null");
                if (y != 1 && y != -1) {
                    throw new IllegalArgumentException("label y must be +1 or -1, got " + y);
                }
                this.x = x;
                this.y = y;
            }
        }
        
        // 感知机类,涵盖初始化、预测、训练、获取四个方法
        class Perceptron {
            public double[] weights; // 权重向量(3维)
            public int iterations;   // 迭代次数
            // 初始化函数:初始权重(0,0,0)
            public Perceptron(int dimension) {
                this.weights = new double[dimension]; // 初始化权重为0
                Arrays.fill(this.weights, 0); // 初始权重为 (0, 0, 0)
                this.iterations = 0;
            }
            // 预测函数:使用 sign(w^T x) 预测标签
            public int predict(double[] x) {
                double dotProduct = 0;
                //用for循环计算向量点乘
                for (int i = 0; i < x.length; i++) {
                    dotProduct += weights[i] * x[i];
                }
                return (dotProduct > 0) ? 1 : -1;
            }
            // 训练函数:更新权重
            public void train(DataPoint point) {
                int prediction = predict(point.x);
                // 如果预测错误(y * (w^T x) <= 0),更新权重
                if (point.y * prediction <= 0) {
                    for (int i = 0; i < weights.length; i++) {
                        weights[i] += point.y * point.x[i]; // w = w + y * x
                    }
                    iterations++;
                    System.out.println("更新权重(迭代 " + iterations + "):
                    w = " + Arrays.toString(weights));
                } else {
                    System.out.println("正确分类,无需更新:w = " + 
                    Arrays.toString(weights));
                }
            }
            // 获取当前权重,保留两位小数
            public double[] getWeights() {
                double[] formattedWeights = new double[weights.length];
                for (int i = 0; i < weights.length; i++) {
                    formattedWeights[i] = Double.parseDouble
                    (String.format("%.2f", weights[i]));
                }
                return formattedWeights;
            }
        }
        
        public class PerceptronExample {
            public static void main(String[] args) {
                // 定义数据集(根据表格数据)
                DataPoint[] data = new DataPoint[12];
                data[0] = new DataPoint(new double[]{3.7, -10.5, 27.49}, 1);
                data[1] = new DataPoint(new double[]{4.8, 0.13, 21.75}, -1);
                data[2] = new DataPoint(new double[]{4.9, 1.25, 22.31}, -1);
                data[3] = new DataPoint(new double[]{4.8, 0.13, 21.75}, -1);
                data[4] = new DataPoint(new double[]{3.7, -10.5, 27.49}, 1);
                data[5] = new DataPoint(new double[]{4.9, 1.25, 22.31}, -1);
                data[6] = new DataPoint(new double[]{3.7, -10.5, 27.49}, 1);
                data[7] = new DataPoint(new double[]{4.9, 1.25, 22.31}, -1);
                data[8] = new DataPoint(new double[]{4.8, 0.13, 21.75}, -1);
                data[9] = new DataPoint(new double[]{3.7, -10.5, 27.49}, 1);
                data[10] = new DataPoint(new double[]{4.9, 1.25, 22.31}, -1);
                data[11] = new DataPoint(new double[]{4.8, 0.13, 21.75}, -1);
        
                // 创建感知机实例(3维特征)
                Perceptron perceptron = new Perceptron(3);
                // 训练模型:遍历数据集
                System.out.println("开始训练,初始权重:w = " + 
                Arrays.toString(perceptron.getWeights()));
                for (int i = 0; i < data.length; i++) {
                    System.out.println("\n处理数据点 " + (i + 1) + ": x = 
                    " + Arrays.toString(data[i].x) + ", y = " + data[i].y);
                    perceptron.train(data[i]);
                    // 预测当前点,验证结果
                    int prediction = perceptron.predict(data[i].x);
                    System.out.println("预测结果:y = " + prediction);
                
                // 输出最终权重
                System.out.println("\n训练完成,最终权重:w = " + 
                Arrays.toString(perceptron.getWeights()));
            }
        }
    }